#%%
# Synthetic clustered multi-task regression experiment. Nodes draw their ground-truth
# parameter vector from a small set of cluster vectors, generate noisy linear
# observations, and the resulting per-node statistics are handed to TraceNormAgent;
# the recovered parameters are then compared to the ground truth.
import numpy as np
import utils
from numpy.random import default_rng
# Seeded generator used for every random draw in this script, so runs are reproducible.
# NOTE(review): the name shadows the stdlib `random` module should it ever be imported here.
random = default_rng(0)
from matplotlib import pyplot as plt
from scipy.linalg import block_diag  # NOTE(review): not used in the visible script
from policy import TraceNormAgent

#%% Graph construction: assign every node to one of the clusters
n_clusters = 2   # number of distinct ground-truth parameter vectors
n_nodes = 3      # number of nodes (tasks)
dim = 10         # feature / parameter dimension

# Random cluster weights in [0, 1), drawn from the seeded generator.
# NOTE(review): how `weights` and `imbalance` shape the assignment is defined in
# utils.create_cluster_indices — confirm there.
cluster_weights = random.uniform(0, 1, size=n_clusters)
# Presumably one cluster index per node (it is used to index the per-cluster
# parameter matrix below) — TODO confirm against utils.
cluster_inds = utils.create_cluster_indices(
    n_points=n_nodes,
    n_clusters=n_clusters,
    weights=cluster_weights,
    imbalance=0.0,
)

# plt.figure()
# plt.imshow(Adj)
# plt.show()
#%

# def random_weighted_adjacency(Adj):
#     weight_mat = random.uniform(0,1, size= Adj.shape)
#     res = Adj * weight_mat
#     res = np.triu(res) + np.tril(res.T)
#     return res

# plt.show()
# Ground truth: one random unit-norm parameter vector per cluster (rows of an
# (n_clusters, dim) matrix), then one row per node selected by its cluster index.
Theta_true_cluster = random.standard_normal((n_clusters, dim))
cluster_norms = np.linalg.norm(Theta_true_cluster, axis=1)
Theta_true_cluster = Theta_true_cluster / cluster_norms[:, np.newaxis]

Theta_true = Theta_true_cluster[cluster_inds]

#%% Synthetic per-node observations and the statistics passed to the agent
dtype = 'float64'
sigma = 0.01                    # standard deviation of the additive Gaussian noise
n_obs_min, n_obs_max = 5, 15    # Generator.integers excludes n_obs_max by default
X, y = [], []                   # per-node design matrices and targets
t = 0                           # total number of observations over all nodes
y_norm_sum = 0.0                # running sum of squared target norms
for i in range(n_nodes):
    # Keep the draw order (count, features, noise) fixed: the seeded run depends on it.
    n_observations = random.integers(n_obs_min, n_obs_max)
    X_temp = random.standard_normal((n_observations, dim))
    X_temp /= np.linalg.norm(X_temp, axis=1, keepdims=True)  # unit-norm rows
    noise = sigma * random.standard_normal(n_observations)
    y_temp = X_temp @ Theta_true[i] + noise
    X.append(X_temp)
    y.append(y_temp)
    t += n_observations
    y_norm_sum += np.dot(y_temp, y_temp)

# Per-node Gram matrices X_i^T X_i, shape (n_nodes, dim, dim), and cross terms
# y_i^T X_i, shape (n_nodes, dim). Dividing by t is currently disabled.
XX = np.array([Xi.T @ Xi for Xi in X], dtype=dtype)
yX = np.array([yi @ Xi for Xi, yi in zip(X, y)], dtype=dtype)
agent = TraceNormAgent()
agent.set_pb_params(A=XX, b=yX, l_nuc_t=1.0)
# %% Solve the regularized problem and plot convergence
Theta_sol = agent.solve_mult_reg_nuc_verbose(gamma=1.01)
# Entry-wise deviation of the solution from the ground-truth parameters.
print(Theta_sol - Theta_true)

costs = np.asarray(agent.costs)
gap = costs - costs.min()
# The best recorded cost has a gap of exactly 0, which a log axis cannot show
# (the log scale would clip it). Mask non-positive gaps as NaN so matplotlib
# leaves a hole there while keeping every other point at its original index.
gap = np.where(gap > 0, gap, np.nan)
plt.semilogy(gap)
plt.ylabel("cost - min(cost)")
plt.show()